import train
import torch
def get_device():
    """Select the preferred torch device string for this machine.

    Preference order is CUDA, then Apple MPS, then CPU. The chosen device
    is printed before being returned.

    Returns:
        str: one of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        # ``torch.backends.mps`` only exists in PyTorch >= 1.12; look it up
        # defensively so older builds fall back to CPU instead of raising
        # AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            device = "mps"
        else:
            device = "cpu"
    print(f"使用设备: {device}")  # prints "Using device: <device>"
    return device


if __name__ == '__main__':
    # Resolve the accelerator first (this also reports it), then start
    # training on the workersafe dataset config using that device.
    selected_device = get_device()
    train.run(data='workersafe.yaml', device=selected_device)